package org.geektimes.web.mvc;

import org.apache.commons.lang.StringUtils;
import org.geektimes.context.ClassicComponentContext;
import org.geektimes.web.mvc.controller.Controller;
import org.geektimes.web.mvc.controller.PageController;
import org.geektimes.web.mvc.controller.RestController;

import javax.servlet.RequestDispatcher;
import javax.servlet.ServletConfig;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.HttpMethod;
import javax.ws.rs.Path;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.*;
import java.util.stream.Stream;

import static java.util.Arrays.asList;
import static org.apache.commons.lang.StringUtils.substringAfter;

/**
 * Front controller servlet: resolves the request path to a {@link Controller}
 * discovered through {@link ServiceLoader} and dispatches the request to it.
 *
 * <p>{@link PageController}s return a view path that the request is forwarded to;
 * {@link RestController}s have their {@code @Path}-annotated handler method invoked
 * reflectively and the return value printed to the response writer.
 */
public class FrontControllerServlet extends HttpServlet {

    /**
     * Cache of request path to {@link Controller} mappings.
     */
    private final Map<String, Controller> controllersMapping = new HashMap<>();

    /**
     * Cache of request path to {@link HandlerMethodInfo} mappings.
     */
    private final Map<String, HandlerMethodInfo> handleMethodInfoMapping = new HashMap<>();


    /**
     * Initializes the servlet and builds the request-path mappings.
     *
     * @param servletConfig the configuration supplied by the servlet container
     * @throws IllegalStateException if the base servlet initialization fails
     */
    @Override
    public void init(ServletConfig servletConfig) {
        try {
            // Without this call the ServletConfig is never stored, so
            // getServletConfig()/getServletContext()/log() would not work.
            super.init(servletConfig);
        } catch (ServletException e) {
            // The signature of this override declares no checked exception, so wrap it.
            throw new IllegalStateException("Failed to initialize " + getClass().getName(), e);
        }
        initHandleMethods();
    }

    /**
     * Loads every {@link Controller} via {@link ServiceLoader} (Java SPI), reads the
     * {@link Path} annotations on the controller class and its public methods, and
     * records one mapping per annotated method. Each controller is also registered
     * with the {@link ClassicComponentContext}.
     */
    private void initHandleMethods() {
        for (Controller controller : ServiceLoader.load(Controller.class)) {
            Class<?> controllerClass = controller.getClass();
            Path pathFromClass = controllerClass.getAnnotation(Path.class);
            // A controller without a class-level @Path contributes an empty prefix
            // instead of failing the whole initialization with a NullPointerException.
            String basePath = pathFromClass == null ? "" : pathFromClass.value();

            for (Method method : controllerClass.getMethods()) {
                Path pathFromMethod = method.getAnnotation(Path.class);
                if (pathFromMethod == null) {
                    // Only @Path-annotated methods are request handlers.
                    continue;
                }
                String requestPath = basePath + pathFromMethod.value();
                Set<String> supportedHttpMethods = findSupportedHttpMethods(method);
                handleMethodInfoMapping.put(requestPath,
                        new HandlerMethodInfo(requestPath, method, supportedHttpMethods));
                controllersMapping.put(requestPath, controller);
            }
            ClassicComponentContext.getInstance().registerComponents(controllerClass, controller);
        }
    }

    /**
     * Collects the HTTP methods declared on a handler method through annotations that
     * are themselves meta-annotated with {@link HttpMethod}.
     *
     * @param method the handler method to inspect
     * @return the declared HTTP method names in declaration order; when none is
     * declared, GET, POST, PUT, DELETE, HEAD and OPTIONS are all returned
     */
    private Set<String> findSupportedHttpMethods(Method method) {
        Set<String> supportedHttpMethods = new LinkedHashSet<>();
        for (Annotation annotationFromMethod : method.getAnnotations()) {
            HttpMethod httpMethod = annotationFromMethod.annotationType().getAnnotation(HttpMethod.class);
            if (httpMethod != null) {
                supportedHttpMethods.add(httpMethod.value());
            }
        }

        if (supportedHttpMethods.isEmpty()) {
            // No explicit restriction on the handler: accept the common HTTP methods.
            supportedHttpMethods.addAll(asList(HttpMethod.GET, HttpMethod.POST,
                    HttpMethod.PUT, HttpMethod.DELETE, HttpMethod.HEAD, HttpMethod.OPTIONS));
        }

        return supportedHttpMethods;
    }

    /**
     * Dispatches the request to the controller mapped to its path.
     *
     * <p>Responds with 404 when no mapping exists for the path and with 405 when the
     * request's HTTP method is not among the handler's supported methods.
     *
     * @param request  the current HTTP request
     * @param response the current HTTP response
     * @throws ServletException if the controller fails with anything other than an
     *                          {@link IOException}; the original failure is the cause
     * @throws IOException      if the controller or the response fails with an I/O error
     */
    @Override
    public void service(HttpServletRequest request, HttpServletResponse response)
            throws ServletException, IOException {
        // e.g. requestURI = /a/hello/world
        String requestURI = request.getRequestURI();
        // e.g. contextPath = /a, or "" for the root context
        String servletContextPath = request.getContextPath();
        // Mapping path = request URI with the context path stripped.
        String requestMappingPath = substringAfter(requestURI,
                StringUtils.replace(servletContextPath, "//", "/"));

        Controller controller = controllersMapping.get(requestMappingPath);
        HandlerMethodInfo handlerMethodInfo = handleMethodInfoMapping.get(requestMappingPath);
        // Both maps are populated together, so a miss in either means "no handler".
        if (controller == null || handlerMethodInfo == null) {
            response.setStatus(HttpServletResponse.SC_NOT_FOUND);
            return;
        }

        if (!handlerMethodInfo.getSupportedHttpMethods().contains(request.getMethod())) {
            response.setStatus(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
            return;
        }

        try {
            if (controller instanceof PageController) {
                renderPage((PageController) controller, request, response);
            } else if (controller instanceof RestController) {
                invokeRestHandler(controller, handlerMethodInfo.getHandlerMethod(), request, response);
            }
        } catch (Throwable throwable) {
            // Reflection wraps whatever the handler threw; report the real failure.
            Throwable failure = throwable;
            if (throwable instanceof InvocationTargetException && throwable.getCause() != null) {
                failure = throwable.getCause();
            }
            if (failure instanceof IOException) {
                throw (IOException) failure;
            }
            if (failure instanceof ServletException) {
                throw (ServletException) failure;
            }
            // Keep the failure as the cause so the container can log the full trace.
            throw new ServletException(failure);
        }
    }

    /**
     * Executes a page controller and forwards the request to the view path it returns.
     */
    private void renderPage(PageController pageController, HttpServletRequest request,
                            HttpServletResponse response) throws Throwable {
        String viewPath = pageController.execute(request, response);
        if (viewPath == null) {
            throw new ServletException("PageController " + pageController.getClass().getName()
                    + " returned a null view path");
        }
        // ServletContext#getRequestDispatcher requires a path beginning with "/".
        if (!viewPath.startsWith("/")) {
            viewPath = "/" + viewPath;
        }
        ServletContext servletContext = request.getServletContext();
        RequestDispatcher requestDispatcher = servletContext.getRequestDispatcher(viewPath);
        if (requestDispatcher == null) {
            throw new ServletException("No RequestDispatcher available for view path: " + viewPath);
        }
        requestDispatcher.forward(request, response);
    }

    /**
     * Invokes a REST handler method reflectively and prints its return value to the
     * response writer.
     */
    private void invokeRestHandler(Controller controller, Method handlerMethod,
                                   HttpServletRequest request, HttpServletResponse response)
            throws Throwable {
        Object[] args = resolveArguments(handlerMethod, request, response);
        Object result = handlerMethod.invoke(controller, args);
        response.getWriter().print(result);
    }

    /**
     * Builds the argument array for a handler method: parameters of type
     * {@link HttpServletRequest} / {@link HttpServletResponse} receive the current
     * request / response; every other parameter receives the request parameter whose
     * name equals the Java parameter name (or {@code null} when absent).
     */
    private Object[] resolveArguments(Method handlerMethod, HttpServletRequest request,
                                      HttpServletResponse response) {
        Parameter[] parameters = handlerMethod.getParameters();
        Object[] args = new Object[parameters.length];
        for (int i = 0; i < parameters.length; i++) {
            Class<?> parameterType = parameters[i].getType();
            if (parameterType == HttpServletRequest.class) {
                args[i] = request;
            } else if (parameterType == HttpServletResponse.class) {
                args[i] = response;
            } else {
                // NOTE(review): Parameter#getName yields real names only when the
                // controller was compiled with -parameters; otherwise it is arg0, arg1, ...
                args[i] = request.getParameter(parameters[i].getName());
            }
        }
        return args;
    }

    /**
     * Tells whether any of the given parameters is declared as {@link HttpServletRequest}.
     * Not referenced anywhere else in this class.
     *
     * @param parameters the method parameters to inspect
     * @return {@code true} if at least one parameter has type {@link HttpServletRequest}
     */
    private boolean hasHttpServletRequest(Parameter[] parameters) {
        return Stream.of(parameters).anyMatch(parameter -> parameter.getType() == HttpServletRequest.class);
    }


}
